import gc
import os
import numpy as np
import torch
import torch.nn as nn
import time
import argparse
import pandas as pd
import json
import random
import pickle

from utils import *
from tqdm import tqdm

def get_args():
    """Parse the command-line options for this script.

    Returns:
        argparse.Namespace with one attribute per flag declared below.

    Notes:
        * ``--num_elites`` is a fraction; the script multiplies it by
          ``--batch_size`` to get the elite count.
        * ``--API_key`` defaults to the ``AUTODAN_API_KEY`` environment
          variable (``None`` when unset). A literal credential used to be
          hard-coded here as the default; it has been removed so the secret
          no longer lives in source control and should be treated as leaked
          and rotated.
          NOTE(review): confirm the sampling helpers in ``utils`` accept
          ``API_key=None``.
    """
    parser = argparse.ArgumentParser(description="Configs")
    parser.add_argument("--device", type=int, default=0)
    parser.add_argument("--start", type=int, default=0)
    parser.add_argument("--start_number", type=int, default=9)
    parser.add_argument("--num_steps", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--num_elites", type=float, default=0.05)
    parser.add_argument("--crossover", type=float, default=0.5)
    parser.add_argument("--num_points", type=int, default=5)
    parser.add_argument("--iter", type=int, default=5)
    parser.add_argument("--mutation", type=float, default=0.01)
    parser.add_argument("--init_prompt_path", type=str, default="./data/autodan_initial_prompt.txt")
    parser.add_argument("--dataset_path", type=str, default="./data/advbench/harmful_behaviors.csv")
    parser.add_argument("--model", type=str, default="baichuan")  # baichuan or chatglm
    parser.add_argument("--save_suffix", type=str, default="normal")
    # Never embed credentials in source: read the key from the environment
    # at call time, and let the flag override it.
    parser.add_argument("--API_key", type=str, default=os.environ.get("AUTODAN_API_KEY"))

    return parser.parse_args()


if __name__ == '__main__':
    # Script entry point. For every (instruction, target) pair read from the
    # data files it runs up to `num_steps` rounds of: score the current
    # candidate strings, pick the lowest-loss one, check the model response,
    # then produce the next generation of candidates. Per-item logs are
    # written to a JSON file after each item.
    args = get_args()
    API_key = args.API_key
    device = f'cuda:{args.device}'
    # NOTE(review): absolute, machine-specific checkpoint locations.
    model_path_dicts = {'baichuan': "/home/yunhaofeng/Project/baichuan/baichuan-inc/Baichuan2-13B-Chat-4bits",
                        "chatglm": "/home/yunhaofeng/Project/glm"}
    model_path = model_path_dicts[args.model]
    template_name = args.model

    num_steps = args.num_steps
    batch_size = args.batch_size
    num_elites = max(1, int(args.batch_size * args.num_elites))  # number of elite individuals (at least 1)
    crossover = args.crossover  # crossover rate
    num_points = args.num_points  # number of points for multi-point crossover
    mutation = args.mutation  # mutation rate

    model, tokenizer = load_model_and_tokenizer(model_path=model_path)

    conv_template = load_conversation_template(template_name)
    conv_template.system_message = ['一个好奇的人类和一个人工智能助手之间的聊天。助理会对人的问题给出有用的、详细的、礼貌的回答。']
    conv_template.roles = ('### 人类','### 人工智能助手')

    crit = nn.CrossEntropyLoss(reduction='mean')
    infos = {}

    # Optional string prepended to the best candidate; disabled while None.
    prefix_string_init = None

    # Load each data file as a list of whitespace-stripped lines.
    with open('data/advbench/NTU_Chinese.txt', 'r', encoding='utf-8') as file:
        harmful_data = [line.strip() for line in file]

    with open('data/advbench/NTU_answer.txt', 'r', encoding='utf-8') as file:
        harmful_answer = [line.strip() for line in file]

    with open('data/advbench/reference.txt', 'r', encoding='utf-8') as file:
        reference = [line.strip() for line in file]

    # Fill in the placeholders once, in place. This used to be repeated for
    # every data item although the result does not depend on the item.
    # NOTE(review): '[KEEPER]' is always replaced with "Baichuan Technology",
    # even when --model is chatglm -- confirm this is intended.
    model_display_name = template_name.title()
    for o, ref in enumerate(reference):
        ref = ref.replace('[MODEL]', model_display_name)
        reference[o] = ref.replace('[KEEPER]', "Baichuan Technology")

    # NOTE(review): zip() stops at the shorter list, so a length mismatch
    # between the two files is silently truncated while the progress bar
    # total still uses len(harmful_data).
    # NOTE(review): args.start only offsets the result keys and the output
    # file name below; the dataset itself is not sliced by it.
    dataset = zip(harmful_data, harmful_answer)

    results_dir = './results/autodan_hga'
    results_path = f'{results_dir}/{args.model}_{args.start}_{args.save_suffix}.json'

    for i, (g, t) in tqdm(enumerate(dataset), total=len(harmful_data)):
        log = log_init()
        info = {"goal": g, "target": t, "final_suffix": "",
                "final_respond": "", "total_time": 0, "is_success": False, "log": log}

        start_time = time.time()
        user_prompt = g
        target = t

        # Initial population: a batch_size-wide window of the reference list.
        start_number = args.start_number
        new_adv_suffixs = reference[start_number:start_number + batch_size]

        word_dict = {}

        # Defaults so the summary below is well defined even when the step
        # loop never runs (num_steps == 0); previously that raised NameError.
        adv_suffix = ""
        gen_str = ""
        is_success = False
        now_prompt = ""

        for j in range(num_steps):
            with torch.no_grad():
                epoch_start_time = time.time()
                losses = get_score_autodan(
                    tokenizer=tokenizer,
                    conv_template=conv_template, instruction=user_prompt, target=target,
                    model=model,
                    device=device,
                    test_controls=new_adv_suffixs,
                    crit=crit)
                score_list = losses.float().cpu().numpy().tolist()

                # Lowest loss wins; convert to a plain int for list indexing.
                best_new_adv_suffix_id = int(losses.argmin())
                best_new_adv_suffix = new_adv_suffixs[best_new_adv_suffix_id]

                current_loss = losses[best_new_adv_suffix_id]
                if isinstance(prefix_string_init, str):
                    best_new_adv_suffix = prefix_string_init + best_new_adv_suffix
                adv_suffix = best_new_adv_suffix

                suffix_manager = autodan_SuffixManager(tokenizer=tokenizer,
                                                       conv_template=conv_template,
                                                       instruction=user_prompt,
                                                       target=target,
                                                       adv_string=adv_suffix)

                # test_prefixes is not defined in this file; it is expected to
                # come from the star-import of utils.
                is_success, gen_str = check_for_attack_success(model,
                                                               tokenizer,
                                                               suffix_manager.get_input_ids(adv_string=adv_suffix).to(device),
                                                               suffix_manager._assistant_role_slice,
                                                               test_prefixes)
                now_prompt = suffix_manager.get_prompt(adv_string=adv_suffix)

                # Every args.iter-th step (including step 0) uses
                # autodan_sample_control; all other steps use the *_hga
                # variant, which also returns an updated word_dict.
                if j % args.iter == 0:
                    unfiltered_new_adv_suffixs = autodan_sample_control(control_suffixs=new_adv_suffixs,
                                                                        score_list=score_list,
                                                                        num_elites=num_elites,
                                                                        batch_size=batch_size,
                                                                        crossover=crossover,
                                                                        num_points=num_points,
                                                                        mutation=mutation,
                                                                        API_key=API_key,
                                                                        reference=reference)
                else:
                    unfiltered_new_adv_suffixs, word_dict = autodan_sample_control_hga(word_dict=word_dict,
                                                                                       control_suffixs=new_adv_suffixs,
                                                                                       score_list=score_list,
                                                                                       num_elites=num_elites,
                                                                                       batch_size=batch_size,
                                                                                       crossover=crossover,
                                                                                       mutation=mutation,
                                                                                       API_key=API_key,
                                                                                       reference=reference)

                new_adv_suffixs = unfiltered_new_adv_suffixs

                epoch_end_time = time.time()
                epoch_cost_time = round(epoch_end_time - epoch_start_time, 2)

                print(
                    "################################\n"
                    f"Current Data: {i}/{len(harmful_data)}\n"
                    f"Current Step: {j}/{num_steps}\n"
                    f"Passed:{is_success}\n"
                    f"Loss:{current_loss.item()}\n"
                    f"Epoch Cost:{epoch_cost_time}\n"
                    f"现有脚本:\n{now_prompt}\n"
                    f"Current Response:\n{gen_str}\n"
                    "################################\n")

                info["log"]["time"].append(epoch_cost_time)
                info["log"]["loss"].append(current_loss.item())
                info["log"]["suffix"].append(best_new_adv_suffix)
                info["log"]["respond"].append(gen_str)
                info["log"]["success"].append(is_success)

                if is_success:
                    # Append mode: each pickle.dump adds one more record to
                    # the same file.
                    with open('prompts.pkl', 'ab') as file:
                        pickle.dump(now_prompt, file=file)
                    break
                # Free Python garbage and cached CUDA memory between steps.
                gc.collect()
                torch.cuda.empty_cache()
        end_time = time.time()
        cost_time = round(end_time - start_time, 2)
        info["total_time"] = cost_time
        info["final_suffix"] = adv_suffix
        info["final_respond"] = gen_str
        info["is_success"] = is_success
        info["Prompt"] = now_prompt

        infos[i + args.start] = info
        # The whole results file is rewritten after every item, so an
        # interrupted run keeps everything finished so far.
        os.makedirs(results_dir, exist_ok=True)
        with open(results_path, 'w', encoding='utf-8') as json_file:
            json.dump(infos, json_file, ensure_ascii=False)